from data import Lattice, Catalogue
from utils import plotting
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
from random import shuffle
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
# Initialise plotly's offline notebook mode so the figures below can be displayed in the notebook.
plotly.offline.init_notebook_mode()
Import the unit cell catalogue from the PNAS paper by Lumpe, T. S. and Stankovic, T. (2020):
https://www.pnas.org/doi/10.1073/pnas.2003504118
The catalogue can be downloaded from
https://doi.org/10.3929/ethz-b-000457598
# Load the raw catalogue downloaded from the ETH Research Collection.
# NOTE(review): `indexing=1` presumably means node indices in this file are 1-based — confirm against Catalogue.from_file.
cat = Catalogue.from_file('./Unit_Cell_Catalog.txt', indexing=1)
print(cat)
Unit cell catalogue with 17222 entries
First, filter out lattices in which any two nodes are closer than 5% of the unit cell size, and plot some of the lattices that are being discarded.
# Split the catalogue into lattices to keep and lattices to discard.
# A lattice is discarded when its two closest nodes (in the 'transformed'
# representation) are nearer than 0.05, i.e. 5% of the unit cell size.
# Per-lattice statistics are recorded only for the lattices that are kept.
print(f'Catalogue before: {len(cat)}')
selected = []
discarded = []
df_data = {}
k = 0  # running count of discarded lattices, shown in the progress bar
pbar = tqdm(cat.names)
for name in pbar:
    lat = Lattice(**cat.get_unit_cell(name))
    distances, dist_indices = lat.calculate_node_distances('transformed')
    min_dist = distances.min()
    if min_dist < 0.05:
        discarded.append(name)
        k += 1
    else:
        selected.append(name)
        df_data[name] = {
            'min_dist': min_dist,
            'num_nodes': lat.num_nodes,
            'num_edges': lat.num_edges,
        }
    pbar.set_postfix({'Discarded': k})
print(f'Catalogue after: {len(selected)}')
# One row per kept lattice, ordered from the closest node pair upwards.
df = pd.DataFrame(df_data).T.sort_values(by='min_dist')
df.describe()
Catalogue before: 17222
100%|██████████| 17222/17222 [00:28<00:00, 609.24it/s, Discarded=7445]
Catalogue after: 9777
| | min_dist | num_nodes | num_edges |
|---|---|---|---|
| count | 9777.000000 | 9777.000000 | 9777.000000 |
| mean | 0.149163 | 57.633221 | 78.552521 |
| std | 0.098145 | 45.435541 | 72.401143 |
| min | 0.050007 | 6.000000 | 6.000000 |
| 25% | 0.082500 | 30.000000 | 37.000000 |
| 50% | 0.121920 | 46.000000 | 60.000000 |
| 75% | 0.181611 | 71.000000 | 92.000000 |
| max | 1.000000 | 806.000000 | 1360.000000 |
# Plot a random sample of discarded lattices, highlighting the nodes that
# belong to pairs closer than the 0.05 threshold.
ncols = 4; nrows = 2
fig = make_subplots(
    rows=nrows, cols=ncols,
    # NOTE(review): 't' looks like a placeholder title, presumably overwritten by plotting.plotly_unit_cell_3d — confirm.
    subplot_titles=['t' for _ in range(nrows*ncols)],
    specs=[[{"type": "scatter3d"} for _ in range(ncols)] for _ in range(nrows)]
)
shuffle(discarded)
# Slicing (instead of indexing discarded[k] over a fixed range) avoids an
# IndexError when fewer than nrows*ncols lattices were discarded.
for k, name in enumerate(discarded[:ncols*nrows]):
    lat = Lattice(**cat.get_unit_cell(name))
    distances, dist_indices = lat.calculate_node_distances('transformed')
    # Rows of dist_indices for every too-close pair, flattened in row order;
    # equivalent to concatenating dist_indices[ind, :] over those rows.
    highlight_nodes = dist_indices[np.flatnonzero(distances < 0.05), :].ravel()
    fig = plotting.plotly_unit_cell_3d(lat, 'transformed', fig=fig, subplot=dict(nrows=nrows, ncols=ncols, index=k), node_numbers=False, highlight_nodes=highlight_nodes)
fig.update_layout(width=1800, height=800)
fig.show()
Inspect some of the lattices that we are keeping: plot the lattices with the shortest node distance.
# Plot the four kept lattices whose closest node pair is nearest to the
# threshold (df is sorted by min_dist), highlighting that closest pair.
fig = make_subplots(
    rows=1, cols=4,
    subplot_titles=['t'] * 4,
    specs=[[{"type": "scatter3d"} for _ in range(4)]]
)
for col, name in enumerate(df.index[:4]):
    lat = Lattice(**cat.get_unit_cell(name))
    distances, indices = lat.calculate_node_distances('transformed')
    closest = np.argmin(distances)
    pair = [indices[closest, 0], indices[closest, 1]]
    fig = plotting.plotly_unit_cell_3d(
        lat,
        repr='transformed',
        fig=fig,
        subplot=dict(nrows=1, ncols=4, index=col),
        highlight_nodes=pair,
        show_uc_box=True,
    )
fig.update_layout(width=2000, height=500)
fig.show()
Fix errors in the lattices that we are keeping: split edges at existing nodes that lie on them, and find and remove intersections between edges.
# Repair every kept lattice:
#   1. split edges at existing nodes that lie on them,
#   2. split edges at edge-edge intersections,
#   3. drop the lattice if the repair created nodes closer than 0.05,
#   4. for repaired lattices, verify no defects remain.
# `selected` is rebuilt to hold only the lattices that survive this step.
k = 0        # number of lattices that needed a repair
skipped = 0  # lattices dropped because repairing created too-close nodes
pbar = tqdm(selected)  # tqdm keeps a reference to the old list, so rebinding below is safe
selected = []
modified_cat = {}
for name in pbar:
    lat = Lattice(**cat.get_unit_cell(name))
    # Single, consistently spelled flag (the original mixed `modifed` and
    # `modified`, so node-on-edge repairs were never counted or re-checked).
    modified = False
    nodes_on_edges = lat.find_nodes_on_edges()
    if nodes_on_edges:
        lat.split_edges_by_points(nodes_on_edges)
        modified = True
    edge_intersections = lat.find_edge_intersections()
    if edge_intersections:
        lat.split_edges_by_points(edge_intersections)
        modified = True
    # NOTE(review): this check uses the 'reduced' representation, while the
    # initial filter used 'transformed' — confirm this is intended.
    distances, _ = lat.calculate_node_distances('reduced')
    if distances.min() < 0.05:
        skipped += 1
    else:
        if modified:
            k += 1
            if lat.find_nodes_on_edges() or lat.find_edge_intersections():
                # Abort loudly instead of `break`, which left a silently
                # truncated catalogue that the next cell would save.
                raise RuntimeError(f'Lattice {name} failed: defects remain after splitting edges')
        modified_cat[name] = lat.print_lattice_lines()
        selected.append(name)
    pbar.set_postfix({'Modified lattices': k, 'Skipped': skipped})
cat = Catalogue.from_dict(modified_cat)
print(f'Catalogue size: {len(cat)}')
100%|██████████| 9777/9777 [03:09<00:00, 51.55it/s, Modified lattices=737, Skipped=106]
Catalogue size: 9671
Save the filtered catalogue to a file.
# Write the repaired catalogue; the next cell reloads it from './filtered_cat.lat' to verify it.
cat.to_file('filtered_cat.lat')
Check the saved dataset.
# Reload the saved catalogue and assert that every lattice passes all three
# cleaning criteria; the first violation aborts with a RuntimeError.
# NOTE(review): the positional 0 is presumably `indexing` (0-based here) — confirm.
cat = Catalogue.from_file('./filtered_cat.lat', 0)
for name in tqdm(cat.names):
    lat = Lattice(**cat.get_unit_cell(name))
    distances, _ = lat.calculate_node_distances()
    closest = distances.min()
    if closest < 0.05:
        raise RuntimeError(f'Minimum nodal distance for lattice {name} is {closest}')
    if lat.find_nodes_on_edges():
        raise RuntimeError(f'Lattice {name} has nodes on edges')
    if lat.find_edge_intersections():
        raise RuntimeError(f'Lattice {name} has edge intersections')
100%|██████████| 9671/9671 [02:36<00:00, 61.77it/s]
Plot examples
# Plot a few randomly chosen lattices from the cleaned catalogue.
# `names` (shuffled) is also reused by the windowing step further down.
names = list(cat.names)
shuffle(names)
n_examples = 4
fig = make_subplots(
    rows=1, cols=n_examples,
    subplot_titles=['t' for _ in range(n_examples)],
    specs=[[{"type": "scatter3d"} for _ in range(n_examples)]]
)
# Iterate the shuffled names. The original looped over the stale
# `df.head(4).index` from the first filter, which ignored the shuffle and
# could name lattices that were dropped during repair.
for j, name in enumerate(names[:n_examples]):
    lat = Lattice(**cat.get_unit_cell(name))
    fig = plotting.plotly_unit_cell_3d(
        lat, repr='transformed',
        fig=fig, subplot=dict(nrows=1, ncols=n_examples, index=j),
        show_uc_box=True
    )
fig.update_layout(width=2000, height=500)
fig.show()
Statistics
# Collect node and edge counts for every lattice in the cleaned catalogue
# and summarise them.
df_data = {}
for name in tqdm(cat.names):
    lat = Lattice(**cat.get_unit_cell(name))
    df_data[name] = dict(num_nodes=lat.num_nodes, num_edges=lat.num_edges)
df = pd.DataFrame(df_data).T
df.describe()
100%|██████████| 9671/9671 [00:02<00:00, 4754.02it/s]
| | num_nodes | num_edges |
|---|---|---|
| count | 9671.000000 | 9671.000000 |
| mean | 58.177438 | 79.823079 |
| std | 45.861697 | 73.673558 |
| min | 6.000000 | 6.000000 |
| 25% | 30.000000 | 38.000000 |
| 50% | 47.000000 | 60.000000 |
| 75% | 72.000000 | 95.000000 |
| max | 806.000000 | 1360.000000 |
# Side-by-side histograms of node and edge counts per unit cell.
# Only two titles are given because the grid has only two subplots; the
# original listed four, and plotly silently ignores the surplus.
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=("Number of nodes", "Number of edges")
)
marker_dict = {'line': {'color': 'black', 'width': 0.5}}
fig.add_histogram(
    x=df['num_nodes'], name='Nodes', row=1, col=1,
    marker=marker_dict
)
fig.add_histogram(
    x=df['num_edges'], name='Edges', row=1, col=2,
    marker=marker_dict,
)
# Clip both x-axes to [0, 400]: the tails reach ~800 nodes / ~1400 edges.
fig.update_layout(
    xaxis_range=[0, 400],
    xaxis2_range=[0, 400],
    title='Unit cell statistics',
    height=400, width=1000, showlegend=False,
)
fig
# Build a sparser, windowed catalogue from the cleaned one:
#   - skip lattices whose closest nodes are nearer than MIN_NODE_DIST,
#   - collapse nodes onto boundaries, split edges at nodes lying on them
#     and at edge intersections (skipping the lattice if splitting creates
#     too-close nodes),
#   - create the windowed lattice (retried up to MAX_WINDOW_ATTEMPTS times),
#   - write all windowed lattices to './catalogue_sparse_windowed.lat'.
newdata = dict()
clustered = []           # NOTE(review): never appended to; kept for the progress display
nodes_on_edges_lat = []  # lattices that needed node-on-edge splitting
splitting_edges = []     # lattices that needed intersection splitting
unmodified = 0
MIN_NODE_DIST = 0.1
MAX_WINDOW_ATTEMPTS = 2  # the original retried create_windowed exactly once
written = 0
# Iterate names directly: `trange` was used here but never imported.
pbar = tqdm(names)
for lattice in pbar:
    lat = Lattice(**cat.get_unit_cell(lattice))
    min_dist = lat.closest_node_distance()[0]
    if min_dist < MIN_NODE_DIST:
        continue
    modified = False
    lat.collapse_nodes_onto_boundaries()
    # Split edges at nodes that lie on them.
    nodes_on_edges = lat.find_nodes_on_edges()
    if nodes_on_edges:
        modified = True
        nodes_on_edges_lat.append(lattice)
        lat.split_edges_by_nodes(nodes_on_edges)
        if lat.find_nodes_on_edges():
            print(f'{lattice} nodes on edges not fixed')
    # Split edges at their mutual intersections.
    edge_int = lat.find_edge_intersections()
    if len(edge_int) > 0:
        modified = True
        splitting_edges.append(lattice)
        lat.split_edges_at_intersections(edge_int)
        min_dist = lat.closest_node_distance()[0]
        if min_dist < MIN_NODE_DIST:
            continue
        edge_int = lat.find_edge_intersections()
        if len(edge_int) > 0:
            print(f'{lattice} intersections not fixed')
    if not modified:
        unmodified += 1
    # Create the windowed lattice. The original retried once after any
    # exception — presumably because create_windowed can fail
    # non-deterministically (NOTE(review): confirm). The for/else runs its
    # `else` only when every attempt raised.
    for _attempt in range(MAX_WINDOW_ATTEMPTS):
        try:
            wlat = lat.create_windowed()
        except Exception:
            continue
        break
    else:
        print(f'Lattice {lattice} failed')
        continue
    newdata[lattice] = wlat.print_lattice_lines()
    written += 1
    pbar.set_postfix(
        clustered=len(clustered),
        nodes_on_edges=len(nodes_on_edges_lat),
        edge_intersections=len(splitting_edges),
        unmodified=unmodified,
        written=written,
        refresh=False
    )
print(f'Writing catalogue of {len(newdata)} lattices to file')
newcat = Catalogue.from_dict(newdata)
newcat.to_file('./catalogue_sparse_windowed.lat')